'''
This is pseudocode to help readers understand the paper.
The full source code is planned to be released to the public.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


    
def PositionalEncoding(d_model, lengths, w_s=None):
    """Build sinusoidal positional encodings for a padded batch.

    Even channels hold ``sin(pos / 10000**(2i / d_model))``. Odd channels hold
    the matching ``cos``. An odd ``d_model`` is supported: the last sine
    channel then has no cosine partner.

    Args:
        d_model: Number of encoding channels.
        lengths: Tensor of sequence lengths, assumed 1-D ``(batch,)``.
            ``max(lengths)`` sets the time dimension, and its device is used
            for every tensor created here.
        w_s: Optional tensor that multiplies the position indices. It is
            unsqueezed on the last dim before broadcasting, so a ``(batch,)``
            shape is presumably expected (TODO: confirm against callers).

    Returns:
        Float tensor of shape ``(len(lengths), max(lengths), d_model)``.
    """
    device = lengths.device
    max_len = int(lengths.max().item())

    position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(0)
    if w_s is not None:
        # Scale positions per sequence: (1, L) * (B, 1) -> (B, L).
        position = position * w_s.unsqueeze(-1)

    div_term = torch.pow(
        10000, torch.arange(0, d_model, 2, device=device).float() / d_model
    )
    # Angles of shape (1 or B, L, ceil(d_model / 2)); broadcast over the batch on assignment.
    angle = position.unsqueeze(-1) / div_term

    pe = torch.zeros(len(lengths), max_len, d_model, device=device)
    pe[:, :, 0::2] = torch.sin(angle)
    # With an odd d_model there is one fewer odd channel than even channel.
    pe[:, :, 1::2] = torch.cos(angle[..., : d_model // 2])
    return pe



def get_mask_from_lengths(lengths):
    """Return a boolean padding mask for a batch of sequence lengths.

    ``mask[b, t]`` is ``True`` when ``t >= lengths[b]``, i.e. at positions past
    the end of sequence ``b``. Valid positions are ``False``.

    Args:
        lengths: Tensor of lengths, presumably 1-D ``(batch,)``.

    Returns:
        Bool tensor of shape ``(len(lengths), max(lengths))`` on
        ``lengths.device``. It never requires grad.
    """
    max_len = torch.max(lengths).item()
    # Build the index row directly with lengths' dtype/device. Copy-constructing
    # via ``lengths.new_tensor(tensor)`` adds a CPU->device copy and a UserWarning.
    ids = torch.arange(0, max_len, dtype=lengths.dtype, device=lengths.device)
    # The comparison already yields a bool tensor detached from autograd.
    return lengths.unsqueeze(1) <= ids